/*
 * Copyright (c) Hisilicon Technologies Co., Ltd. 2022-2022. All rights reserved.
 * Description: general ir quantize saver
 */

#include "omg/quantize_optimizer/general_ir_quantize_saver.h"

#include <list>
#include <algorithm>

#include "graph/types.h"
#include "graph/op/all_ops.h"

#include "infra/base/assertion.h"
#include "infra/base/securestl.h"

#include "framework/graph/core/edge/endpoint.h"
#include "framework/graph/core/node/node_spec.h"
#include "framework/graph/core/node/node_walker.h"
#include "framework/graph/core/cgraph/graph_list_walker.h"
#include "framework/graph/core/cgraph/graph_modifier.h"
#include "framework/graph/utils/attr_utils.h"
#include "framework/graph/utils/node_utils.h"
#include "framework/graph/op/internal_defs.h"

#include "omg/quantize_optimizer/quantize_util.h"

using namespace std;
using namespace ge;

namespace hiai {

// Side of the target node where a quantize operator chain is inserted:
// POS_INPUT splices into an incoming data edge, POS_OUTPUT into outgoing ones.
enum class PositionType {
    POS_INPUT,
    POS_OUTPUT,
};

namespace {
// Converts a weight node's OpDesc into a QuantizedConst carrying the given
// quantization scale and offset lists as attributes.
inline void SetWeightQuantParams(
    ge::Node& filterNode, const vector<float>& weightScale, const vector<float>& weightOffset)
{
    auto& desc = filterNode.ROLE(NodeSpec).OpDesc();
    desc.SetType(hiai::op::QuantizedConst::TYPE);
    desc.SetAttr(hiai::op::QuantizedConst::scale, AttrValue::CreateFrom(weightScale));
    desc.SetAttr(hiai::op::QuantizedConst::offset, AttrValue::CreateFrom(weightOffset));
}

// Verifies that a weight node already converted to QuantizedConst (i.e. shared
// by another quantized consumer) carries exactly the same scales as requested.
// Returns FAILURE when the sizes differ or any scale deviates beyond ZERO_EPS.
Status CheckSharedWeightQuantParams(ge::Node& node, const vector<float>& weightScale)
{
    OpDesc& opDesc = node.ROLE(NodeSpec).OpDesc();
    vector<float> savedScales;
    HIAI_EXPECT_TRUE(AttrUtils::GetListFloat(opDesc, hiai::op::QuantizedConst::scale, savedScales));
    HIAI_EXPECT_TRUE(savedScales.size() == weightScale.size());

    const bool scalesMatch = std::equal(savedScales.cbegin(), savedScales.cend(), weightScale.cbegin(),
        [](float lhs, float rhs) { return fabs(lhs - rhs) <= ZERO_EPS; });
    if (!scalesMatch) {
        FMK_LOGE("Shared weight has different weight scales, op:%s.", opDesc.GetName().c_str());
        return hiai::FAILURE;
    }

    return hiai::SUCCESS;
}

bool SetScaleAndOffsetAttr(const OperatorParam& operatorParam, ge::OpDescPtr& quantizeOpDesc)
{
    HIAI_EXPECT_TRUE_R(
        AttrUtils::SetListFloat(*quantizeOpDesc, hiai::op::QuantizeV2::scale, operatorParam.scale), false);
    return AttrUtils::SetListFloat(*quantizeOpDesc, hiai::op::QuantizeV2::offset, operatorParam.offset);
}

// Writes scale, offset and the target quantized dtype attributes; used by the
// QUANTIZE and REQUANTIZE operator kinds which all need an explicit out dtype.
bool SetScaleOffsetAndDtypeAttr(const OperatorParam& operatorParam, ge::OpDescPtr& quantizeOpDesc)
{
    HIAI_EXPECT_TRUE_R(SetScaleAndOffsetAttr(operatorParam, quantizeOpDesc), false);
    return AttrUtils::SetInt(*quantizeOpDesc, hiai::op::QuantizeV2::dtype, operatorParam.dataType);
}

// Writes the DequantizeV2 deq_scale attribute (dequantize needs only scales).
bool SetDeqScaleAttr(const OperatorParam& operatorParam, ge::OpDescPtr& quantizeOpDesc)
{
    return AttrUtils::SetListFloat(*quantizeOpDesc, hiai::op::DequantizeV2::deq_scale, operatorParam.scale);
}

// Quantizes a constant (weight) input in place. If the weight was already
// turned into a QuantizedConst by another consumer, only verify that the
// shared quant params agree; otherwise stamp the params and quantize the data.
Status ProcessConstInput(Node& srcNode, const OperatorParam& operatorParam, int64_t weightDataAddr)
{
    const bool alreadyQuantized = (srcNode.ROLE(NodeSpec).Type() == hiai::op::QuantizedConst::TYPE);
    if (alreadyQuantized) {
        HIAI_EXPECT_EXEC(CheckSharedWeightQuantParams(srcNode, operatorParam.scale));
        return hiai::SUCCESS;
    }
    SetWeightQuantParams(srcNode, operatorParam.scale, operatorParam.offset);
    return QuantizeUtil::QuantizeWeight(srcNode, operatorParam.dataType, operatorParam.scale, false, weightDataAddr);
}

// Builds an OpDesc for one quantize-family operator (QuantizeV2 / DequantizeV2
// / Requantize / AntiQuantize) with a unique name derived from the target node
// name, insertion side, edge index and position in the operator chain.
// Returns nullptr on unknown operator type or attribute-set failure.
ge::OpDescPtr CreateQuantizeOpDesc(
    const string& name, PositionType posType, uint32_t index, const OperatorParam& operatorParam)
{
    // static: the kind -> IR type-string table never changes; build it once.
    static const std::map<OperatorType, std::string> operatorTypeMap {
        { OperatorType::QUANTIZE, std::string(hiai::op::QuantizeV2::TYPE) },
        { OperatorType::DEQUANTIZE, std::string(hiai::op::DequantizeV2::TYPE) },
        { OperatorType::REQUANTIZE, std::string(hiai::op::Requantize::TYPE) },
        { OperatorType::ANTIQUANTIZE, std::string(hiai::op::AntiQuantize::TYPE) } };

    auto it = operatorTypeMap.find(operatorParam.operType);
    HIAI_EXPECT_TRUE_R(it != operatorTypeMap.end(), nullptr);

    const std::string& operatorType = it->second;
    // Encode position/index/order into the name so inserted nodes stay unique.
    string posTypeStr = (posType == PositionType::POS_INPUT) ? "input" : "output";
    string opName = name + "_" + posTypeStr + "_" + to_string(index) + "_" +
        to_string(operatorParam.operIndex) + "_" + operatorType;
    ge::OpDescPtr quantizeOpDesc = hiai::make_shared_nothrow<ge::OpDesc>(opName, operatorType);
    HIAI_EXPECT_NOT_NULL_R(quantizeOpDesc, nullptr);

    // Every quantize-family op here has exactly one input and one output.
    TensorDesc tmpTensor;
    quantizeOpDesc->AddInputDesc(tmpTensor);
    quantizeOpDesc->AddOutputDesc(tmpTensor);

    switch (operatorParam.operType) {
        case OperatorType::QUANTIZE:
        case OperatorType::REQUANTIZE:
            // Both need scale/offset plus an explicit output dtype.
            HIAI_EXPECT_TRUE_R(SetScaleOffsetAndDtypeAttr(operatorParam, quantizeOpDesc), nullptr);
            break;
        case OperatorType::DEQUANTIZE:
            HIAI_EXPECT_TRUE_R(SetDeqScaleAttr(operatorParam, quantizeOpDesc), nullptr);
            break;
        case OperatorType::ANTIQUANTIZE:
            HIAI_EXPECT_TRUE_R(SetScaleAndOffsetAttr(operatorParam, quantizeOpDesc), nullptr);
            break;
        default:
            // Defensive; the map lookup above already rejects unknown kinds.
            // static_cast: passing a scoped enum through varargs (%d) is UB.
            FMK_LOGE("Operator type:%d not supported currently.", static_cast<int>(operatorParam.operType));
            return nullptr;
    }

    return quantizeOpDesc;
}

// Materializes the prepared OpDescs as nodes in targetNode's graph, chains
// them head-to-tail (out:0 -> in:0), and splices the chain into the graph:
//  - POS_INPUT:  original producer -> chain -> targetNode's input `index`
//  - POS_OUTPUT: targetNode's output `index` -> chain -> all former consumers
Status InsertNodesToGraph(
    vector<ge::OpDescPtr>& insertOpDescs, uint32_t index, PositionType posType, Node& targetNode)
{
    ge::ComputeGraph& ownerGraph = targetNode.ROLE(NodeSpec).OwnerComputeGraph();
    vector<ge::Node*> insertNodes;
    // const ref: avoid copying the shared_ptr (refcount churn) each iteration.
    for (const auto& quantizeOpDesc : insertOpDescs) {
        ge::Node* quantizeNode = ownerGraph.ROLE(GraphModifier).AddNode(quantizeOpDesc);
        HIAI_EXPECT_NOT_NULL_R(quantizeNode, hiai::PARAM_INVALID);
        insertNodes.push_back(quantizeNode);
    }
    // Link the inserted nodes into a chain: node[i-1].out(0) -> node[i].in(0).
    for (size_t i = 1; i < insertNodes.size(); i++) {
        HIAI_EXPECT_TRUE(ownerGraph.ROLE(GraphModifier).AddEdge(
            Endpoint(*insertNodes[i - 1], 0), Endpoint(*insertNodes[i], 0)) == hiai::SUCCESS);
    }
    if (posType == PositionType::POS_INPUT) {
        // Cut the existing producer -> target edge and route it through the chain.
        auto inDataEdge = targetNode.ROLE(NodeWalker).InDataEdge(index);
        HIAI_EXPECT_TRUE(inDataEdge.Exist());
        Endpoint srcOut(inDataEdge->SrcNode(), inDataEdge->SrcIdx());
        HIAI_EXPECT_TRUE(
            ownerGraph.ROLE(GraphModifier).RemoveEdge(inDataEdge->SrcNode(), targetNode) == hiai::SUCCESS);
        HIAI_EXPECT_TRUE(ownerGraph.ROLE(GraphModifier).AddEdge(
            srcOut, Endpoint(*insertNodes[0], 0)) == hiai::SUCCESS);
        HIAI_EXPECT_TRUE(ownerGraph.ROLE(GraphModifier).AddEdge(
            Endpoint(*insertNodes[insertNodes.size() - 1], 0), Endpoint(targetNode, index)) == hiai::SUCCESS);
    } else {
        // Record all current consumers of output `index`, then detach them.
        vector<Endpoint> dstEndPoints;
        auto relinkOutputDataNode = [&dstEndPoints, &ownerGraph](Edge& outputDataEdge) {
            dstEndPoints.push_back(Endpoint(outputDataEdge.DstNode(), outputDataEdge.DstIdx()));
            return ownerGraph.ROLE(GraphModifier).RemoveEdge(outputDataEdge);
        };
        HIAI_EXPECT_TRUE(
            targetNode.ROLE(NodeWalker).ListOutDataEdges(index, std::move(relinkOutputDataNode)) == hiai::SUCCESS);

        // target -> chain head, chain tail -> each former consumer.
        HIAI_EXPECT_TRUE(ownerGraph.ROLE(GraphModifier).AddEdge(
            Endpoint(targetNode, index), Endpoint(*insertNodes[0], 0)) == hiai::SUCCESS);
        for (size_t i = 0; i < dstEndPoints.size(); i++) {
            HIAI_EXPECT_TRUE(ownerGraph.ROLE(GraphModifier).AddEdge(
                Endpoint(*insertNodes[insertNodes.size() - 1], 0), dstEndPoints[i]) == hiai::SUCCESS);
        }
    }

    return hiai::SUCCESS;
}

// Rebases the operator's quantization offsets when the user-facing data type
// and the configured quantized type are the signed/unsigned counterparts of
// the same width (e.g. int8 <-> uint8 shifts the zero point by 128). On a
// match, offsets are shifted and the operator dtype becomes the user's type;
// otherwise the params are left untouched.
void AdjustQuantizeOffsets(ge::DataType userDataType, OperatorParam& operParam)
{
    // static: the correction table is immutable; build it once, not per call.
    static const std::map<std::pair<ge::DataType, ge::DataType>, float> offsetValues = {
        {std::make_pair(ge::DT_INT8, ge::DT_UINT8), -128.0f},
        {std::make_pair(ge::DT_UINT8, ge::DT_INT8), 128.0f},
        {std::make_pair(ge::DT_INT16, ge::DT_UINT16), -32768.0f},
        {std::make_pair(ge::DT_UINT16, ge::DT_INT16), 32768.0f}
    };
    // Single lookup instead of find() followed by operator[].
    const auto it = offsetValues.find(std::make_pair(userDataType, operParam.dataType));
    if (it == offsetValues.end()) {
        return;
    }
    const float offsetValue = it->second;
    for (size_t j = 0; j < operParam.offset.size(); j++) {
        operParam.offset[j] += offsetValue;
    }
    operParam.dataType = userDataType;
}
// Tensor attribute marking that an integer-typed graph input/output carries
// already-quantized data (as opposed to a plain integer tensor).
const char* const TENSOR_ATTR_QUANTIZED_DATA_TYPE = "quantized_data_type";
// Integer data types that may represent quantized graph inputs/outputs.
// const: this lookup table is never modified after initialization.
const std::list<ge::DataType> INPUT_OUTPUT_QUANTIZE_DATATYPES = {ge::DT_UINT8, ge::DT_INT8, ge::DT_UINT16, ge::DT_INT16};
// Handles a graph input (Data node, srcNode) whose tensor is already
// quantized: when the input dtype is a quantized integer type and the tensor
// carries TENSOR_ATTR_QUANTIZED_DATA_TYPE, the first QUANTIZE operator's
// offsets are rebased to the user's input type and an AntiQuantize node is
// inserted right after the Data node so downstream ops see dequantized data.
// No-op (SUCCESS) when the input is not marked as quantized data.
Status ProcessQuantizedDataInput(Node& srcNode, Node& targetNode, QuantizeParams& params)
{
    // Locate the first operator of the chain (operIndex == 0).
    size_t i = 0;
    for (; i < params.operatorParams.size(); i++) {
        if (params.operatorParams[i].operIndex == 0) {
            break;
        }
    }
    ge::TensorDescPtr inputTensor = srcNode.ROLE(NodeSpec).OpDesc().MutableInputDesc(0);
    HIAI_EXPECT_NOT_NULL(inputTensor);
    ge::DataType dataInputType = inputTensor->GetDataType();
    if (std::find(INPUT_OUTPUT_QUANTIZE_DATATYPES.begin(), INPUT_OUTPUT_QUANTIZE_DATATYPES.end(), dataInputType) ==
        INPUT_OUTPUT_QUANTIZE_DATATYPES.end()) {
        return hiai::SUCCESS;
    }
    bool isQuanatizedDataType = false;
    if (inputTensor->HasAttr(TENSOR_ATTR_QUANTIZED_DATA_TYPE)) {
        ge::AttrUtils::GetBool(*inputTensor, TENSOR_ATTR_QUANTIZED_DATA_TYPE, isQuanatizedDataType);
    }
    if (!isQuanatizedDataType) {
        return hiai::SUCCESS; // integer-typed input that is not quantized data
    }

    // A quantized input requires the first chain operator to be QUANTIZE.
    if (i >= params.operatorParams.size() || params.operatorParams[i].operType != OperatorType::QUANTIZE) {
        FMK_LOGE("Input quant oper index or operType is illegal.");
        return FAILURE;
    }
    AdjustQuantizeOffsets(dataInputType, params.operatorParams[i]);

    // AntiQuantize mirrors the (adjusted) quantize params to undo them.
    OperatorParam antiQuantizeParam {0, OperatorType::ANTIQUANTIZE,
        params.operatorParams[i].dataType, params.operatorParams[i].scale, params.operatorParams[i].offset};

    ge::OpDescPtr antiQuantizeOpDesc = CreateQuantizeOpDesc(targetNode.ROLE(NodeSpec).Name(),
        PositionType::POS_INPUT, params.index, antiQuantizeParam);
    HIAI_EXPECT_NOT_NULL(antiQuantizeOpDesc);
    vector<ge::OpDescPtr> insertOpDescs = {antiQuantizeOpDesc};
    // NOTE(review): inserted at targetNode's input 0, not params.index —
    // confirm this is intended when the Data input is not input 0.
    if (InsertNodesToGraph(insertOpDescs, 0, PositionType::POS_INPUT, targetNode) != SUCCESS) {
        FMK_LOGE("Insert antiquantize node after Data node fail.");
        return FAILURE;
    }
    return SUCCESS;
}


// Handles a graph output (edge into NetOutput) that must stay quantized: when
// the output dtype is a quantized integer type and the tensor is marked with
// TENSOR_ATTR_QUANTIZED_DATA_TYPE, the trailing ANTIQUANTIZE operator is
// removed from the chain and the new last operator (which must be QUANTIZE or
// REQUANTIZE) gets its offsets rebased to the user's output type.
// No-op (SUCCESS) when the output is not marked as quantized data.
Status ProcessQuantizedOutput(ge::Edge& outEdge, QuantizeParams& params)
{
    ge::TensorDescPtr outputTensor = outEdge.DstNode().ROLE(NodeSpec).OpDesc().MutableOutputDesc(outEdge.DstIdx());
    HIAI_EXPECT_NOT_NULL(outputTensor);
    ge::DataType outDataType = outputTensor->GetDataType();
    if (std::find(INPUT_OUTPUT_QUANTIZE_DATATYPES.begin(), INPUT_OUTPUT_QUANTIZE_DATATYPES.end(), outDataType) ==
        INPUT_OUTPUT_QUANTIZE_DATATYPES.end()) {
        return hiai::SUCCESS;
    }
    bool isQuanatizedDataType = false;
    if (outputTensor->HasAttr(TENSOR_ATTR_QUANTIZED_DATA_TYPE)) {
        ge::AttrUtils::GetBool(*outputTensor, TENSOR_ATTR_QUANTIZED_DATA_TYPE, isQuanatizedDataType);
        // The attr is consumed here so it is not serialized with the model.
        (void)outputTensor->DelAttr(TENSOR_ATTR_QUANTIZED_DATA_TYPE);
    }
    if (!isQuanatizedDataType) {
        return hiai::SUCCESS; // integer-typed output that is not quantized data
    }
    // Locate the last operator of the chain (operIndex == size - 1).
    auto it = params.operatorParams.begin();
    while (it != params.operatorParams.end()) {
        if (it->operIndex == params.operatorParams.size() - 1) {
            break;
        }
        it++;
    }
    // For a network with quantized outputs, the last inserted operator must be
    // Quantize or Requantize so the final output stays in a quantized type;
    // therefore the trailing AntiQuantize is required here and then dropped.
    if (it == params.operatorParams.end() || it->operType != OperatorType::ANTIQUANTIZE) {
        FMK_LOGE("OperIndex is out of operator nums or operType is illegal.");
        return FAILURE;
    }
    params.operatorParams.erase(it);
    // After the erase, find the new last operator (operIndex == size - 1).
    size_t i = 0;
    for (; i < params.operatorParams.size(); i++) {
        if (params.operatorParams[i].operIndex == (params.operatorParams.size() - 1)) {
            break;
        }
    }
    if (i >= params.operatorParams.size() || ((params.operatorParams[i].operType != OperatorType::QUANTIZE) &&
        (params.operatorParams[i].operType != OperatorType::REQUANTIZE))) {
        FMK_LOGE("Input quant oper index is illegal or operType is illegal.");
        return FAILURE;
    }
    AdjustQuantizeOffsets(outDataType, params.operatorParams[i]);
    return SUCCESS;
}

// Applies one QuantizeParams entry to targetNode: const inputs are quantized
// in place; Data inputs and NetOutput-bound outputs may first have the param
// chain adapted for already-quantized user data; finally the (possibly
// adjusted) operator chain is built and spliced into the graph at the
// input/output identified by params.index.
Status InsertQuantizeNodes(const QuantizeParams& params, Node& targetNode, PositionType posType,
    int64_t weightDataAddr)
{
    // Work on a copy: Data-input / output processing may rewrite the params.
    QuantizeParams unifiedParams = params;
    if (posType == PositionType::POS_INPUT) {
        ge::Node* srcNode = targetNode.ROLE(NodeWalker).InDataNode(params.index);
        // Missing producer with INT32 quant params means an absent bias
        // weight; treated as benign and skipped with a warning.
        if (srcNode == nullptr && params.operatorParams[0].dataType == ge::DT_INT32) {
            FMK_LOGW(
                "Current node:%s has no bias weight, but has quant params.", targetNode.ROLE(NodeSpec).Name().c_str());
            return hiai::SUCCESS;
        }
        HIAI_EXPECT_NOT_NULL(srcNode);
        if (srcNode->ROLE(NodeSpec).IsConstOp()) {
            // Const weights are quantized directly; no nodes are inserted.
            HIAI_EXPECT_TRUE(params.operatorParams.size() == 1);
            return ProcessConstInput(*srcNode, params.operatorParams[0], weightDataAddr);
        }
        if (srcNode->ROLE(NodeSpec).IsDataOp()) {
            HIAI_EXPECT_TRUE(ProcessQuantizedDataInput(*srcNode, targetNode, unifiedParams) == SUCCESS);
        }
    } else if (posType == PositionType::POS_OUTPUT) {
        // Only edges feeding the network output need quantized-output handling.
        HIAI_EXPECT_TRUE(targetNode.ROLE(NodeWalker).ListOutDataEdges(params.index,
            [&unifiedParams](ge::Edge& outEdge) {
            if (outEdge.DstNode().ROLE(NodeSpec).IsNetOutputOp()) {
                HIAI_EXPECT_TRUE(ProcessQuantizedOutput(outEdge, unifiedParams) == SUCCESS);
            }
            return hiai::SUCCESS;
        }) == hiai::SUCCESS);
    }

    // Create one OpDesc per chain entry, placed at its operIndex so the
    // insertion order matches the configured operator order.
    vector<ge::OpDescPtr> insertOpDescs(unifiedParams.operatorParams.size(), nullptr);
    for (size_t i = 0; i < unifiedParams.operatorParams.size(); i++) {
        ge::OpDescPtr quantizeOpDesc = CreateQuantizeOpDesc(
            targetNode.ROLE(NodeSpec).Name(), posType, unifiedParams.index, unifiedParams.operatorParams[i]);
        HIAI_EXPECT_NOT_NULL_R(quantizeOpDesc, hiai::PARAM_INVALID);
        insertOpDescs[unifiedParams.operatorParams[i].operIndex] = quantizeOpDesc;
    }
    return InsertNodesToGraph(insertOpDescs, unifiedParams.index, posType, targetNode);
}

// Applies every QuantizeParams entry to targetNode on the given side,
// stopping at (and returning) the first failure.
Status ProcessQuantParams(ge::Node& targetNode, const std::vector<QuantizeParams>& quantParams, PositionType posType,
    int64_t weightDataAddr = 0)
{
    for (const QuantizeParams& quantParam : quantParams) {
        HIAI_EXPECT_EXEC(InsertQuantizeNodes(quantParam, targetNode, posType, weightDataAddr));
    }
    return hiai::SUCCESS;
}
} // namespace

// Persists general-IR quantize configuration onto a node: inserts the
// configured input/output quantize operator chains (quantizing const weights
// in place, using weightDataAddr for the weight data) and, when present,
// attaches the extended quant info to the node's OpDesc.
Status GeneralIRQuantizeSaver::SaveOpQuantParams(
    const QuantizeV2Config& quantizeConfig, Node* node, int64_t weightDataAddr)
{
    // Guard the raw pointer before the unconditional dereferences below.
    HIAI_EXPECT_NOT_NULL(node);
    HIAI_EXPECT_EXEC(
        ProcessQuantParams(*node, quantizeConfig.inputQuantParams, PositionType::POS_INPUT, weightDataAddr));
    HIAI_EXPECT_EXEC(ProcessQuantParams(*node, quantizeConfig.outputQuantParams, PositionType::POS_OUTPUT));

    if (quantizeConfig.hasQuantInfoExt) {
        HIAI_EXPECT_EXEC(QuantizeUtil::SetQuantInfoExt(node->ROLE(NodeSpec).OpDesc(), quantizeConfig.quantInfoExt));
    }

    return hiai::SUCCESS;
}
} // namespace hiai
